{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "os.chdir('../..')\n",
    "sys.path.insert(1, os.path.join(sys.path[0], '../..'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = [\n",
    "    'models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03/',\n",
    "    'models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03_sup2/',\n",
    "    'models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/',\n",
    "    'models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_sup2/',\n",
    "    'models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/',\n",
    "    'models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1_sup2/',\n",
    "    'models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/',\n",
    "    'models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1_sup2/',\n",
    "    'models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/',\n",
    "    'models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_sup2/',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import normalized_mutual_info_score\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def compute_nmi(model, K=1251):\n",
    "    embeddings_path = f'{model}/embeddings_vox1_avg.pt'\n",
    "\n",
    "    embeddings = torch.load(embeddings_path)\n",
    "\n",
    "    if len(embeddings) != 153516:\n",
    "        print('Invalid embeddings', model, len(embeddings))\n",
    "\n",
    "    X = np.concatenate(list(embeddings.values()))\n",
    "    y_speaker = [y.split('/')[-3] for y in embeddings.keys()]\n",
    "    y_video = [y.split('/')[-2] for y in embeddings.keys()]\n",
    "\n",
    "    kmeans = KMeans(n_clusters=K, init='random', algorithm='lloyd', random_state=0).fit(X)\n",
    "\n",
    "    nmi_speaker = normalized_mutual_info_score(y_speaker, kmeans.labels_)\n",
    "    nmi_video = normalized_mutual_info_score(y_video, kmeans.labels_)\n",
    "\n",
    "    return nmi_speaker, nmi_video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in tqdm(MODELS):\n",
    "    nmi_speaker, nmi_video = compute_nmi(model)\n",
    "    print(f'Model: {model} - NMI Speaker: {nmi_speaker} - NMI Video: {nmi_video} - Ratio: {nmi_speaker / nmi_video}')\n",
    "    with open(f'{model}/nmi.json', 'w') as f:\n",
    "        json.dump({\n",
    "            \"vox1_nmi_speaker\": nmi_speaker,\n",
    "            \"vox1_nmi_video\": nmi_video\n",
    "        }, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_py-3.13.3_torch-2.7.1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
